#################################################################
# imports, etc.

import sys
from utils import *

# NOTE(review): `reform_flag` is not read anywhere else in this script;
# confirm nothing else depends on it before removing.
reform_flag=3


##################################################
# load data

# Scenario selected by the FIRST command-line argument only.
# flag -> (suffix, gap-data column used as the empirical target).
# The suffix names the input file (inpath + 'tpuprobs' + suffix + '.txt') and
# is appended to every output file name.
_SCENARIOS = {
    '-avg-sector': ('_avg_sector', 'temp_b'),
    '-one-sector': ('_one_sector', 'temp_b'),
    '-calvo':      ('_calvo',      'temp_b'),
    '-lo-xi':      ('_lo_xi',      'temp_b'),
    '-hi-xi':      ('_hi_xi',      'temp_b'),
    '-lo-rhoxi':   ('_lo_rhoxi',   'temp_b'),
    '-hi-rhoxi':   ('_hi_rhoxi',   'temp_b'),
    '-sunkcost':   ('_sunkcost',   'temp_b'),
    '-permtpu':    ('_permtpu',    'temp_b'),
    '-ci-lower':   ('_ci_lower',   'temp_lb'),
    '-ci-upper':   ('_ci_upper',   'temp_ub'),
}
# Flags that may appear anywhere in argv and are not scenario selectors.
_OTHER_FLAGS = ('-calc-te', '-ags-sens')

print('\tloading simulation data')

gaps = pd.read_stata('../../20 Intermediate Files/inputs_calibration_gapelast.dta')
gaps = gaps.fillna(0)

_scenario_flag = sys.argv[1] if len(sys.argv) > 1 else None

if _scenario_flag in _SCENARIOS:
    suff, _gap_column = _SCENARIOS[_scenario_flag]
    suff2 = suff
    probs = np.genfromtxt(inpath + 'tpuprobs' + suff + '.txt')
else:
    if _scenario_flag is not None and _scenario_flag not in _OTHER_FLAGS:
        # A typo here would otherwise silently re-run (and overwrite) the
        # baseline outputs.
        sys.stderr.write('\twarning: unrecognized scenario flag %r; '
                         'running baseline\n' % (_scenario_flag,))
    suff = ''
    suff2 = '_baseline'
    _gap_column = 'temp_b'
    # The baseline file's last entry is dropped (scenario files are used
    # whole). NOTE(review): presumably it carries one extra period - confirm.
    probs = np.genfromtxt(inpath + 'tpuprobs_baseline.txt')[:-1]

# Empirical gap coefficients that the model is calibrated against.
actual = gaps[_gap_column]

# Passed to gap_regression_AKKRS below; 1 when '-ags-sens' is given.
flag = 1 if '-ags-sens' in sys.argv else 0

# load data and aggregate it
# df[0] / df[1] feed the 'No TPU' / 'TPU' coefficient curves plotted below.
df = [load_data(1, suff, suff2), load_data(3, suff, suff2)]

#######################################
# trade elasticities

_run_ecm = '-calc-te' in sys.argv
if _run_ecm:
    # Optional step (enabled by '-calc-te' anywhere in argv): run the ECM
    # regression on the second simulated dataset (df[1]) and save both of its
    # returned values (originally named err_SR / err_LR) for calibration.
    print('\testimating trade elasticities via ECM')
    # Unpacking enforces that ecm_regression returns exactly two values.
    short_run, long_run = ecm_regression(df[1], suff2)
    ecm_file = ''.join([calpath, 'caldata_ecm', suff2, '.txt'])
    np.savetxt(ecm_file, np.array([short_run, long_run]))

#######################################
# annual gap coefficients

print('\testimating annual NNTR gap coefficients')

# One gap regression per simulated dataset, in the same order as `df`.
effects = [gap_regression_AKKRS(sim_data, 'exports', flag) for sim_data in df]

# Calibration target `actual2`:
#   * positions 0..hp_start keep the raw coefficients from `actual`;
#   * later positions take the second output of hpfilter (lamb=2) run on
#     actual[hp_start:] (presumably statsmodels' (cycle, trend) pair). The
#     filter's first value, at position hp_start, is discarded so that
#     position stays raw;
#   * the last element is pinned to zero.
hp_start = 7
_, smoothed = hpfilter(actual[hp_start:], lamb=2)
actual2 = np.append(actual[:hp_start + 1], smoothed[1:])
actual2[-1] = 0.0

# Calibration moments: coefficients from df[1] minus the empirical target.
caldata = effects[1] - actual2

np.savetxt(''.join([calpath, 'caldata', suff2, '.txt']), caldata)

#############################################################################
print('\tplotting coefficients + probabilities')

# x-axis for the coefficient plot. Every series below must have one value per
# year. This single definition is reused for all series and for the axis
# limits, instead of repeating the literal range.
years = range(1974, 2009)

fig, ax = slide_fig()

# Empirical gap coefficients: raw (with markers) and smoothed (dashed).
ax.plot(years, actual, color=colors[2],
        marker='o', markersize=3, alpha=0.8, label='Raw data', linewidth=lw)

ax.plot(years, actual2, color=colors[2], marker=None, linestyle='--',
        alpha=0.8, label='Smoothed data', linewidth=lw)

# Model coefficients from the two simulated datasets (effects[1], effects[0]).
ax.plot(years, effects[1], color=colors[1],
        alpha=0.8, label='TPU', linewidth=lw)

ax.plot(years, effects[0], color=colors[0],
        alpha=0.8, label='No TPU', linestyle='--', linewidth=lw)

# Zero line plus dotted vertical markers at years NR and NU.
ax.axhline(0, color='black', linestyle='-', linewidth=1, alpha=1, zorder=1)
ax.axvline(NR, color='black', linestyle=':', linewidth=1, alpha=0.7, zorder=2)
ax.axvline(NU, color='black', linestyle=':', linewidth=1, alpha=0.7, zorder=3)

# Same limits as before (1974, 2008), now derived from `years`.
ax.set_xlim(years[0], years[-1])
ax.legend(loc='lower right', prop={'size': tw})

plt.savefig(outpath + 'gap_coefficients' + suff + '.pdf',
            bbox_inches='tight')

plt.close('all')


# x-axis for the probability plot: one year per element of `probs`,
# starting in 1973.
t = range(1973, 2008)
nr_pos = NR - 1973  # index of year NR within `probs`

# Both curves are built from the same `probs` series, split around year NR:
#   p1 follows probs through year NR-1 and stays at that value afterwards;
#   p2 is held at its year-NR value before NR and follows probs from NR on.
p1 = probs.copy()
p1[nr_pos - 1:] = p1[nr_pos - 1]
p2 = probs.copy()
p2[:nr_pos] = probs[nr_pos]

fig, ax = slide_fig()

curves = (
    (p1, colors[0], r'$P(NNTR\rightarrow MFN)$'),
    (p2, colors[1], r'$P(MFN\rightarrow NNTR)$'),
)
for series, colour, curve_label in curves:
    ax.plot(t, series, color=colour, alpha=0.7, label=curve_label, linewidth=lw)
for marker_year in (NR, NU):
    ax.axvline(marker_year, color='black', linestyle=':', linewidth=1, alpha=0.7)

ax.set_xlim(1973, 2008)
ax.legend(loc='upper right', prop={'size': tw})
fig.subplots_adjust(hspace=0.2, wspace=0.25)
plt.savefig(outpath + 'probabilities' + suff + '.pdf', bbox_inches='tight')
plt.close('all')

